{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SD Trainer Kaggle\n",
    "#### Created by [licyk](https://github.com/licyk)\n",
    "\n",
    "Jupyter Notebook 仓库：[licyk/sd-webui-all-in-one](https://github.com/licyk/sd-webui-all-in-one)\n",
    "\n",
    "一个在 [Kaggle](https://www.kaggle.com) 部署 [SD Trainer](https://github.com/Akegarasu/lora-scripts) 的 Jupyter Notebook。\n",
    "\n",
    "使用时请按顺序运行笔记单元。\n",
    "\n",
    "## 提示：\n",
    "1. 可以将训练数据上传至`/kaggle/input`文件夹，运行安装时会将训练数据复制到`/kaggle/data`文件夹中。\n",
    "2. 训练需要的模型将下载至`/kaggle/lora-scripts/sd-models`文件夹中。\n",
    "3. 推荐将训练输出的模型路径改为`/kaggle/working`文件夹，方便下载。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 参数配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentinel read by the install cell to confirm this configuration cell has run.\n",
    "INIT_CONFIG = 1\n",
    "\n",
    "# 消息格式输出\n",
    "def echo(*args):\n",
    "    for i in args:\n",
    "        print(f\":: {i}\")\n",
    "\n",
    "\n",
    "# Copy the files in /kaggle/input to /kaggle/data\n",
    "def cp_data():\n",
    "    \"\"\"Copy every entry of /kaggle/input into /kaggle/data.\n",
    "\n",
    "    The destination directory is created when it is missing. When the input\n",
    "    directory does not exist a message is printed and nothing is copied.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    src_root = \"/kaggle/input\"\n",
    "    dst_root = \"/kaggle/data\"\n",
    "    os.makedirs(dst_root, exist_ok=True)\n",
    "    if not os.path.isdir(src_root):\n",
    "        echo(f\"未找到 {src_root}，跳过训练数据复制\")\n",
    "        return\n",
    "    for entry in os.listdir(src_root):\n",
    "        file_path = os.path.join(src_root, entry)\n",
    "        # Quoted so that names containing spaces reach cp as a single argument.\n",
    "        !cp -rf \"{file_path}\" \"{dst_root}\"\n",
    "\n",
    "\n",
    "# ARIA2\n",
    "class ARIA2:\n",
    "    \"\"\"Download helper that shells out to the aria2c command-line tool.\"\"\"\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "\n",
    "\n",
    "    def __init__(self, workspace, workfolder) -> None:\n",
    "        \"\"\"Store the workspace root and the working folder name.\"\"\"\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "\n",
    "\n",
    "    # Downloader\n",
    "    def aria2(self, url, path, filename):\n",
    "        \"\"\"Download `url` to `path`/`filename` with aria2c.\n",
    "\n",
    "        The download runs when the target file is missing, or when a sibling\n",
    "        file with the `.aria2` suffix exists next to it. Otherwise the file is\n",
    "        reported as already present and nothing is downloaded.\n",
    "        \"\"\"\n",
    "        import os\n",
    "        target = path + \"/\" + filename\n",
    "        control = target + \".aria2\"\n",
    "\n",
    "        # Guard clause: target present and no `.aria2` file beside it.\n",
    "        if os.path.exists(target) and not os.path.exists(control):\n",
    "            echo(f\"{filename} 文件已存在，路径: {path}/{filename}\")\n",
    "            return\n",
    "\n",
    "        echo(f\"开始下载 {filename} ，路径: {path}/{filename}\")\n",
    "        !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M \"{url}\" -d \"{path}\" -o \"{filename}\"\n",
    "        finished = os.path.exists(target) and not os.path.exists(control)\n",
    "        if finished:\n",
    "            echo(f\"{filename} 下载完成\")\n",
    "        else:\n",
    "            echo(f\"{filename} 下载中断\")\n",
    "\n",
    "\n",
    "    # Checkpoint model download (placeholder, overridden by subclasses)\n",
    "    def get_sd_model(self, url, filename):\n",
    "        \"\"\"Placeholder hook; does nothing in this class.\"\"\"\n",
    "        pass\n",
    "\n",
    "\n",
    "    # VAE model download (placeholder, overridden by subclasses)\n",
    "    def get_vae_model(self, url, filename):\n",
    "        \"\"\"Placeholder hook; does nothing in this class.\"\"\"\n",
    "        pass\n",
    "\n",
    "\n",
    "# GIT\n",
    "class GIT:\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "\n",
    "\n",
    "    def __init__(self, workspace, workfolder) -> None:\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "\n",
    "\n",
    "    # 检测要克隆的项目是否存在于指定路径\n",
    "    def exists(self, addr=None, path=None, name=None):\n",
    "        import os\n",
    "        if addr is not None:\n",
    "            if path is None and name is None:\n",
    "                path = os.getcwd() + \"/\" + addr.split(\"/\").pop().split(\".git\", 1)[0]\n",
    "            elif path is None and name is not None:\n",
    "                path = os.getcwd() + \"/\" + name\n",
    "            elif path is not None and name is None:\n",
    "                path = os.path.normpath(path) + \"/\" + addr.split(\"/\").pop().split(\".git\", 1)[0]\n",
    "\n",
    "        if os.path.exists(path):\n",
    "            return True\n",
    "        else:\n",
    "            return False\n",
    "\n",
    "\n",
    "    # 克隆项目\n",
    "    def clone(self, addr, path=None, name=None):\n",
    "        import os\n",
    "        repo = addr.split(\"/\").pop().split(\".git\", 1)[0]\n",
    "        if not self.exists(addr, path, name):\n",
    "            echo(f\"开始下载 {repo}\")\n",
    "            if path is None and name is None:\n",
    "                path = os.getcwd()\n",
    "                name = repo\n",
    "            elif path is not None and name is None:\n",
    "                name = repo\n",
    "            elif path is None and name is not None:\n",
    "                path = os.getcwd()\n",
    "            !git clone {addr} \"{path}/{name}\" --recurse-submodules\n",
    "        else:\n",
    "            echo(f\"{repo} 已存在\")\n",
    "\n",
    "\n",
    "\n",
    "# TUNNEL\n",
    "class TUNNEL:\n",
    "    \"\"\"Expose a local port to the internet through several tunnelling services.\"\"\"\n",
    "    LOCALHOST_RUN = \"localhost.run\"\n",
    "    REMOTE_MOE = \"remote.moe\"\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "    PORT = \"\"\n",
    "\n",
    "\n",
    "    def __init__(self, workspace, workfolder, port) -> None:\n",
    "        \"\"\"Store the workspace root, the working folder name and the local port to expose.\"\"\"\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "        self.PORT = port\n",
    "\n",
    "\n",
    "    # ngrok tunnel\n",
    "    def ngrok(self, ngrok_token: str):\n",
    "        \"\"\"Return the public URL of the first open ngrok tunnel, connecting one to `self.PORT` if none is open.\"\"\"\n",
    "        from pyngrok import conf, ngrok\n",
    "        conf.get_default().auth_token = ngrok_token\n",
    "        conf.get_default().monitor_thread = False\n",
    "        port = self.PORT\n",
    "        ssh_tunnels = ngrok.get_tunnels(conf.get_default())\n",
    "        if len(ssh_tunnels) == 0:\n",
    "            ssh_tunnel = ngrok.connect(port, bind_tls=True)\n",
    "            return ssh_tunnel.public_url\n",
    "        else:\n",
    "            return ssh_tunnels[0].public_url\n",
    "\n",
    "\n",
    "    # cloudflare tunnel\n",
    "    def cloudflare(self):\n",
    "        \"\"\"Start a tunnel for `self.PORT` with pycloudflared and return its `tunnel` URL.\"\"\"\n",
    "        from pycloudflared import try_cloudflare\n",
    "        port = self.PORT\n",
    "        urls = try_cloudflare(port).tunnel\n",
    "        return urls\n",
    "\n",
    "\n",
    "    from typing import Union\n",
    "    from pathlib import Path\n",
    "\n",
    "    # Generate an ssh key\n",
    "    def gen_key(self, path: Union[str, Path]) -> None:\n",
    "        \"\"\"Create a passphrase-less 4096-bit RSA key at `path` and restrict it to mode 0o600.\n",
    "\n",
    "        Raises subprocess.CalledProcessError when ssh-keygen exits with an error.\n",
    "        \"\"\"\n",
    "        import subprocess\n",
    "        from pathlib import Path\n",
    "        path = Path(path)\n",
    "        # Built as an argument list so a path containing spaces stays one argument.\n",
    "        args = [\"ssh-keygen\", \"-t\", \"rsa\", \"-b\", \"4096\", \"-N\", \"\", \"-q\", \"-f\", path.as_posix()]\n",
    "        subprocess.run(args, check=True)\n",
    "        path.chmod(0o600)\n",
    "\n",
    "\n",
    "    # ssh tunnel\n",
    "    def ssh_tunnel(self, host: str) -> Union[str, None]:\n",
    "        \"\"\"Open an ssh reverse tunnel to `host` and return the public URL it announces.\n",
    "\n",
    "        The key `id_rsa` under `self.WORKSPACE` is used and generated when\n",
    "        missing; if generation fails there, a temporary directory is used.\n",
    "        The URL is also stored in the LOCALHOST_RUN or REMOTE_MOE environment\n",
    "        variable. Returns None (after printing a message) when no URL shows up\n",
    "        in the first lines of the ssh output.\n",
    "        \"\"\"\n",
    "        import subprocess\n",
    "        import atexit\n",
    "        import re\n",
    "        import os\n",
    "        from pathlib import Path\n",
    "        from tempfile import TemporaryDirectory\n",
    "\n",
    "        ssh_name = \"id_rsa\"\n",
    "        ssh_path = Path(self.WORKSPACE) / ssh_name\n",
    "        port = self.PORT\n",
    "\n",
    "        tmp = None\n",
    "        if not ssh_path.exists():\n",
    "            try:\n",
    "                self.gen_key(ssh_path)\n",
    "            # write permission error or etc\n",
    "            except subprocess.CalledProcessError:\n",
    "                tmp = TemporaryDirectory()\n",
    "                ssh_path = Path(tmp.name) / ssh_name\n",
    "                self.gen_key(ssh_path)\n",
    "\n",
    "        # Argument list instead of a split string: the key path may contain spaces.\n",
    "        args = [\n",
    "            \"ssh\", \"-R\", f\"80:127.0.0.1:{port}\",\n",
    "            \"-o\", \"StrictHostKeyChecking=no\",\n",
    "            \"-i\", ssh_path.as_posix(), host,\n",
    "        ]\n",
    "\n",
    "        tunnel = subprocess.Popen(\n",
    "            args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding=\"utf-8\"\n",
    "        )\n",
    "\n",
    "        atexit.register(tunnel.terminate)\n",
    "        if tmp is not None:\n",
    "            atexit.register(tmp.cleanup)\n",
    "\n",
    "        is_localhost_run = host == self.LOCALHOST_RUN\n",
    "        # Number of output lines scanned for the URL before giving up.\n",
    "        lines = 27 if is_localhost_run else 5\n",
    "        localhostrun_pattern = re.compile(r\"(?P<url>https?://\\S+\\.lhr\\.life)\")\n",
    "        remotemoe_pattern = re.compile(r\"(?P<url>https?://\\S+\\.remote\\.moe)\")\n",
    "        pattern = localhostrun_pattern if is_localhost_run else remotemoe_pattern\n",
    "        env_key = \"LOCALHOST_RUN\" if is_localhost_run else \"REMOTE_MOE\"\n",
    "\n",
    "        for _ in range(lines):\n",
    "            line = tunnel.stdout.readline()\n",
    "            if line.startswith(\"Warning\"):\n",
    "                print(line, end=\"\")\n",
    "\n",
    "            url_match = pattern.search(line)\n",
    "            if url_match:\n",
    "                tunnel_url = url_match.group(\"url\")\n",
    "                os.environ[env_key] = tunnel_url\n",
    "                return tunnel_url\n",
    "\n",
    "        echo(f\"启动 {host} 内网穿透失败\")\n",
    "        return None\n",
    "\n",
    "\n",
    "    # localhost.run tunnel\n",
    "    def localhost_run(self):\n",
    "        \"\"\"Open an ssh tunnel through localhost.run and return its URL (None on failure).\"\"\"\n",
    "        urls = self.ssh_tunnel(self.LOCALHOST_RUN)\n",
    "        return urls\n",
    "\n",
    "\n",
    "    # remote.moe tunnel\n",
    "    def remote_moe(self):\n",
    "        \"\"\"Open an ssh tunnel through remote.moe and return its URL (None on failure).\"\"\"\n",
    "        urls = self.ssh_tunnel(self.REMOTE_MOE)\n",
    "        return urls\n",
    "\n",
    "\n",
    "    # gradio tunnel\n",
    "    def gradio(self):\n",
    "        \"\"\"Run `gradio-tunneling` for `self.PORT` and return the gradio.live URL it prints.\n",
    "\n",
    "        Returns None (after printing a message) when no URL appears in the\n",
    "        first five lines of output.\n",
    "        \"\"\"\n",
    "        import subprocess\n",
    "        import atexit\n",
    "        import re\n",
    "        port = self.PORT\n",
    "        cmd = [\"gradio-tunneling\", \"--port\", str(port)]\n",
    "        tunnel = subprocess.Popen(\n",
    "            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding=\"utf-8\"\n",
    "        )\n",
    "\n",
    "        atexit.register(tunnel.terminate)\n",
    "\n",
    "        lines = 5\n",
    "        pattern = re.compile(r\"(?P<url>https?://\\S+\\.gradio\\.live)\")\n",
    "\n",
    "        for _ in range(lines):\n",
    "            line = tunnel.stdout.readline()\n",
    "            if line.startswith(\"Warning\"):\n",
    "                print(line, end=\"\")\n",
    "            url_match = pattern.search(line)\n",
    "            if url_match:\n",
    "                return url_match.group(\"url\")\n",
    "\n",
    "        echo(f\"启动 Gradio 内网穿透失败\")\n",
    "        return None\n",
    "\n",
    "\n",
    "    # Start the tunnels\n",
    "    def start(self, ngrok=False, ngrok_token=None, cloudflare=False, remote_moe=False, localhost_run=False, gradio=False):\n",
    "        \"\"\"Start every requested tunnel and print the resulting access URLs.\n",
    "\n",
    "        Each flag enables one service. ngrok additionally needs a non-empty\n",
    "        `ngrok_token`; without one it is skipped with a message. Services that\n",
    "        are disabled or fail to produce a URL are listed as None.\n",
    "        \"\"\"\n",
    "        # An empty token string counts as \"no token\", same as None.\n",
    "        use_ngrok = ngrok is True and bool(ngrok_token)\n",
    "        if ngrok is True and not use_ngrok:\n",
    "            echo(\"未填写 Ngrok Token，已跳过 Ngrok 内网穿透\")\n",
    "\n",
    "        if use_ngrok or cloudflare is True or remote_moe is True or localhost_run is True or gradio is True:\n",
    "            echo(\"启动内网穿透\")\n",
    "\n",
    "        cloudflare_url = self.cloudflare() if cloudflare is True else None\n",
    "        ngrok_url = self.ngrok(ngrok_token) if use_ngrok else None\n",
    "        remote_moe_url = self.remote_moe() if remote_moe is True else None\n",
    "        localhost_run_url = self.localhost_run() if localhost_run is True else None\n",
    "        gradio_url = self.gradio() if gradio is True else None\n",
    "\n",
    "        echo(\"下方为访问地址\")\n",
    "        print(\"==================================================================================\")\n",
    "        echo(f\"CloudFlare: {cloudflare_url}\")\n",
    "        echo(f\"Ngrok: {ngrok_url}\")\n",
    "        echo(f\"remote.moe: {remote_moe_url}\")\n",
    "        echo(f\"localhost_run: {localhost_run_url}\")\n",
    "        echo(f\"Gradio: {gradio_url}\")\n",
    "        print(\"==================================================================================\")\n",
    "\n",
    "\n",
    "\n",
    "# ENV\n",
    "class ENV:\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "\n",
    "\n",
    "    def __init__(self, workspace, workfolder) -> None:\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "\n",
    "\n",
    "    # 准备ipynb笔记自身功能的依赖\n",
    "    def prepare_env_depend(self, use_mirror=True):\n",
    "        if use_mirror is True:\n",
    "            pip_mirror = \"--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html\"\n",
    "        else:\n",
    "            pip_mirror = \"--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html\"\n",
    "\n",
    "        echo(\"安装自身组件依赖\")\n",
    "        !pip install pyngrok pycloudflared gradio-tunneling {pip_mirror}\n",
    "        !apt update\n",
    "        !apt install aria2 ssh google-perftools -y\n",
    "\n",
    "\n",
    "    # 安装pytorch和xformers\n",
    "    def prepare_torch(self, torch_ver, xformers_ver, use_mirror=False):\n",
    "        arg = \"--no-deps\"\n",
    "        if use_mirror is True:\n",
    "            pip_mirror = \"--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html\"\n",
    "        else:\n",
    "            pip_mirror = \"--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html\"\n",
    "        \n",
    "        if torch_ver != \"\":\n",
    "            echo(\"安装 PyTorch\")\n",
    "            !pip install {torch_ver} {pip_mirror}\n",
    "        if xformers_ver != \"\":\n",
    "            echo(\"安装 xFormers\")\n",
    "            !pip install {xformers_ver} {pip_mirror} {arg}\n",
    "    \n",
    "\n",
    "    # 安装requirements.txt依赖\n",
    "    def install_requirements(self, path, use_mirror=False):\n",
    "        import os\n",
    "        if use_mirror is True:\n",
    "            pip_mirror = \"--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html\"\n",
    "        else:\n",
    "            pip_mirror = \"--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html\"\n",
    "        if os.path.exists(path):\n",
    "            echo(\"安装依赖\")\n",
    "            !pip install -r \"{path}\" {pip_mirror}\n",
    "        else:\n",
    "            echo(\"依赖文件路径为空\")\n",
    "\n",
    "\n",
    "    # python软件包安装\n",
    "    # 可使用的操作:\n",
    "    # 安装: install -> install\n",
    "    # 仅安装: install_single -> install --no-deps\n",
    "    # 强制重装: force_install -> install --force-reinstall\n",
    "    # 仅强制重装: force_install_single -> install --force-reinstall --no-deps\n",
    "    # 更新: update -> install --upgrade\n",
    "    # 卸载: uninstall -y\n",
    "    def py_pkg_manager(self, pkg, type=None, use_mirror=False):\n",
    "        if use_mirror is True:\n",
    "            pip_mirror = \"--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html\"\n",
    "        else:\n",
    "            pip_mirror = \"--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html\"\n",
    "\n",
    "        if type == \"install\":\n",
    "            func = \"install\"\n",
    "            args = \"\"\n",
    "        elif type == \"install_single\":\n",
    "            func = \"install\"\n",
    "            args = \"--no-deps\"\n",
    "        elif type == \"force_install\":\n",
    "            func = \"install\"\n",
    "            args = \"--force-reinstall\"\n",
    "        elif type == \"force_install_single\":\n",
    "            func = \"install\"\n",
    "            args = \"install --force-reinstall --no-deps\"\n",
    "        elif type == \"update\":\n",
    "            func = \"install\"\n",
    "            args = \"--upgrade\"\n",
    "        elif type == \"uninstall\":\n",
    "            func = \"uninstall\"\n",
    "            args = \"-y\"\n",
    "            pip_mirror = \"\"\n",
    "        else:\n",
    "            echo(f\"未知操作: {type}\")\n",
    "            return\n",
    "        echo(f\"执行操作: pip {func} {pkg} {args} {pip_mirror}\")\n",
    "        !pip {func} {pkg} {args} {pip_mirror}\n",
    "\n",
    "\n",
    "    # 配置内存优化\n",
    "    def tcmalloc(self):\n",
    "        echo(\"配置内存优化\")\n",
    "        import os\n",
    "        os.environ[\"LD_PRELOAD\"] = \"/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4\"\n",
    "\n",
    "\n",
    "\n",
    "# MANAGER\n",
    "class MANAGER:\n",
    "    \"\"\"Notebook utilities: clear output, check the GPU, render checkbox pickers.\"\"\"\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "\n",
    "\n",
    "    def __init__(self, workspace, workfolder) -> None:\n",
    "        \"\"\"Store the workspace root and the working folder name.\"\"\"\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "\n",
    "\n",
    "    # Clear the notebook cell output\n",
    "    def clear_up(self, opt):\n",
    "        \"\"\"Clear the current cell output; `opt` is forwarded as clear_output's `wait` flag.\"\"\"\n",
    "        from IPython.display import clear_output\n",
    "        clear_output(wait=opt)\n",
    "\n",
    "\n",
    "    # Check whether a GPU is available\n",
    "    def check_gpu(self):\n",
    "        \"\"\"Raise an Exception when TensorFlow reports no GPU device name.\"\"\"\n",
    "        echo(\"检测 GPU 是否可用\")\n",
    "        import tensorflow as tf\n",
    "        echo(f\"TensorFlow 版本: {tf.__version__}\")\n",
    "        if not tf.test.gpu_device_name():\n",
    "            echo(\"GPU 不可用\")\n",
    "            raise Exception(\"\\n没有使用GPU，请在代码执行程序-更改运行时类型-设置为GPU！\\n如果不能使用GPU，建议更换账号！\")\n",
    "        echo(\"GPU 可用\")\n",
    "\n",
    "\n",
    "    def select_list(self, data, name):\n",
    "        \"\"\"Show one checkbox per entry of `data` inside an accordion titled `name`.\n",
    "\n",
    "        Each entry is indexed as `[url, flag]`; the checkbox is labelled with the\n",
    "        last path component of the url and starts ticked when the flag equals 1.\n",
    "        The returned list is updated in place on every checkbox change and holds\n",
    "        the urls of the ticked entries.\n",
    "        \"\"\"\n",
    "        # https://stackoverflow.com/questions/57219796/ipywidgets-dynamic-creation-of-checkboxes-and-selection-of-data\n",
    "        # https://gist.github.com/MattJBritton/9dc26109acb4dfe17820cf72d82f1e6f\n",
    "        import ipywidgets as widgets\n",
    "        # Selectable labels and their checkbox widgets, in the order of `data`.\n",
    "        names = [entry[0].split(\"/\").pop() for entry in data]\n",
    "        checkbox_objects = [\n",
    "            widgets.Checkbox(value=(entry[1] == 1), description=label, )\n",
    "            for entry, label in zip(data, names)\n",
    "        ]\n",
    "        arg_dict = dict(zip(names, checkbox_objects))\n",
    "\n",
    "        panel = widgets.VBox(children=checkbox_objects)\n",
    "\n",
    "        selected_data = []\n",
    "        select_value = [] # ticked state of every checkbox\n",
    "        url_list = [] # urls of the ticked entries\n",
    "        def select_data(**kwargs): # runs on every checkbox change\n",
    "            select_value[:] = [state is True for state in kwargs.values()]\n",
    "            selected_data[:] = [label for label, state in kwargs.items() if state is True]\n",
    "\n",
    "            # Bullet list of the ticked labels.\n",
    "            summary = \"\".join(f\"\\n- {label}\" for label in selected_data)\n",
    "            print(f\"已选择列表: {summary}\")\n",
    "            # Slice assignment keeps the list object that was returned to the caller.\n",
    "            url_list[:] = [data[idx][0] for idx, ticked in enumerate(select_value) if ticked]\n",
    "\n",
    "        out = widgets.interactive_output(select_data, arg_dict)\n",
    "        panel.children = [*panel.children, out]\n",
    "        accordion = widgets.Accordion(children=[panel,], titles=(name,))\n",
    "        display(accordion)\n",
    "        return url_list\n",
    "\n",
    "\n",
    "\n",
    "# SD_TRAINER\n",
    "class SD_TRAINER(ARIA2, GIT, TUNNEL, MANAGER, ENV):\n",
    "    \"\"\"Installer for SD Trainer (lora-scripts) built from the helper classes above.\"\"\"\n",
    "    WORKSPACE = \"\"\n",
    "    WORKFOLDER = \"\"\n",
    "    # Local port handed to the tunnel helper.\n",
    "    PORT = 28000\n",
    "\n",
    "    # Class-level fallback kept for code that reads SD_TRAINER.tun directly;\n",
    "    # every instance replaces it in __init__ with a tunnel bound to its own workspace.\n",
    "    tun = TUNNEL(WORKSPACE, WORKFOLDER, PORT)\n",
    "\n",
    "    def __init__(self, workspace, workfolder) -> None:\n",
    "        \"\"\"Store the workspace root and folder name and create the tunnel helper for them.\"\"\"\n",
    "        self.WORKSPACE = workspace\n",
    "        self.WORKFOLDER = workfolder\n",
    "        self.tun = TUNNEL(workspace, workfolder, self.PORT)\n",
    "\n",
    "\n",
    "    # Directory that receives both checkpoint and VAE downloads\n",
    "    def _models_dir(self):\n",
    "        \"\"\"Return `<WORKSPACE>/<WORKFOLDER>/sd-models`.\"\"\"\n",
    "        return self.WORKSPACE + \"/\" + self.WORKFOLDER + \"/sd-models\"\n",
    "\n",
    "\n",
    "    def get_sd_model(self, url, filename = None):\n",
    "        \"\"\"Download a checkpoint into the models directory; `filename` defaults to the last url component.\"\"\"\n",
    "        filename = url.split(\"/\").pop() if filename is None else filename\n",
    "        self.aria2(url, self._models_dir(), filename)\n",
    "\n",
    "\n",
    "    def get_vae_model(self, url, filename = None):\n",
    "        \"\"\"Download a VAE into the models directory; `filename` defaults to the last url component.\"\"\"\n",
    "        filename = url.split(\"/\").pop() if filename is None else filename\n",
    "        self.aria2(url, self._models_dir(), filename)\n",
    "\n",
    "\n",
    "    def get_sd_model_from_list(self, list):\n",
    "        \"\"\"Download every non-empty url in `list` as a checkpoint.\"\"\"\n",
    "        for i in list:\n",
    "            if i != \"\":\n",
    "                self.get_sd_model(i, i.split(\"/\").pop())\n",
    "\n",
    "\n",
    "    def get_vae_model_from_list(self, list):\n",
    "        \"\"\"Download every non-empty url in `list` as a VAE.\"\"\"\n",
    "        for i in list:\n",
    "            if i != \"\":\n",
    "                self.get_vae_model(i, i.split(\"/\").pop())\n",
    "\n",
    "\n",
    "    def install_kohya_requirements(self):\n",
    "        \"\"\"Install sd-scripts/requirements.txt from inside the sd-scripts folder, then return to the work folder.\"\"\"\n",
    "        import os\n",
    "        # Use the instance's own paths rather than the notebook globals.\n",
    "        work_dir = self.WORKSPACE + \"/\" + self.WORKFOLDER\n",
    "        os.chdir(work_dir + \"/sd-scripts\")\n",
    "        self.install_requirements(work_dir + \"/sd-scripts/requirements.txt\")\n",
    "        os.chdir(work_dir)\n",
    "\n",
    "\n",
    "    def fix_lang(self):\n",
    "        \"\"\"Set the LANG environment variable to zh_CN.UTF-8.\"\"\"\n",
    "        import os\n",
    "        os.environ[\"LANG\"] = \"zh_CN.UTF-8\"\n",
    "\n",
    "\n",
    "    def fix_py_package(self, package, use_mirror):\n",
    "        \"\"\"Remove `package` from site-packages by hand, then uninstall and reinstall it with pip.\"\"\"\n",
    "        import sysconfig\n",
    "        # An empty name would make the rm commands below wipe all of site-packages.\n",
    "        if not package:\n",
    "            echo(\"未指定要修复的软件包\")\n",
    "            return\n",
    "        # site-packages of the running interpreter instead of a fixed python3.10 path.\n",
    "        site_dir = sysconfig.get_paths()[\"purelib\"]\n",
    "        !rm -rf \"{site_dir}/{package}\"\n",
    "        !rm -rf \"{site_dir}/{package}\"-*\n",
    "        self.py_pkg_manager(package, \"uninstall\", use_mirror)\n",
    "        self.py_pkg_manager(package, \"install\", use_mirror)\n",
    "\n",
    "\n",
    "    def install(self, torch_ver, xformers_ver, sd, vae, use_mirror):\n",
    "        \"\"\"Run the whole installation.\n",
    "\n",
    "        Steps: GPU check, notebook dependencies, clone of lora-scripts,\n",
    "        PyTorch/xFormers, requirements.txt, pinned protobuf and numpy,\n",
    "        environment tweaks, then the model downloads for the urls in `sd`\n",
    "        and `vae`.\n",
    "        \"\"\"\n",
    "        import os\n",
    "        os.chdir(self.WORKSPACE)\n",
    "        self.check_gpu()\n",
    "        self.prepare_env_depend(use_mirror)\n",
    "        self.clone(\"https://github.com/Akegarasu/lora-scripts\", self.WORKSPACE)\n",
    "        os.chdir(f\"{self.WORKSPACE}/{self.WORKFOLDER}\")\n",
    "        # Forward the mirror choice like every other pip step does.\n",
    "        self.prepare_torch(torch_ver, xformers_ver, use_mirror)\n",
    "        req_file = self.WORKSPACE + \"/\" + self.WORKFOLDER + \"/requirements.txt\"\n",
    "        self.fix_py_package(\"aiohttp\", use_mirror)\n",
    "        self.install_requirements(req_file, use_mirror)\n",
    "        self.py_pkg_manager(\"protobuf==3.20.0\", \"install\", use_mirror)\n",
    "        self.py_pkg_manager(\"numpy==1.26.4\", \"force_install\", use_mirror)\n",
    "        self.fix_lang()\n",
    "        self.tcmalloc()\n",
    "        self.get_sd_model_from_list(sd)\n",
    "        self.get_vae_model_from_list(vae)\n",
    "\n",
    "#############################################################\n",
    "\n",
    "echo(\"初始化功能完成\")\n",
    "try:\n",
    "    echo(\"尝试安装 ipywidgets 组件\")\n",
    "    !pip install ipywidgets -qq\n",
    "    from IPython.display import clear_output\n",
    "    clear_output(wait=False)\n",
    "    INIT_CONFIG = 1\n",
    "# Only real errors are converted; a keyboard interrupt still stops the cell as usual.\n",
    "except Exception as err:\n",
    "    raise Exception(\"未初始化功能\") from err\n",
    "\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "# Root directory used for the installation and the folder name of the checkout inside it.\n",
    "WORKSPACE = \"/kaggle\"\n",
    "WORKFOLDER = \"lora-scripts\"\n",
    "# Initial values for the tunnel and version options; the install cell\n",
    "# overwrites them with whatever is set in the widgets below.\n",
    "USE_NGROK = False\n",
    "NGROK_TOKEN = \"\"\n",
    "USE_CLOUDFLARE = False\n",
    "USE_REMOTE_MOE = True\n",
    "USE_LOCALHOST_RUN = True\n",
    "USE_GRADIO_SHARE = False\n",
    "TORCH_VER = \"\"\n",
    "XFORMERS_VER = \"\"\n",
    "# Each entry is [download link of a Stable Diffusion model, flag]; a flag of 1\n",
    "# pre-ticks the model's checkbox so it is downloaded unless the user unticks it.\n",
    "sd_model = [\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/v1-5-pruned-emaonly.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/animefull-final-pruned.safetensors\", 1],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/Counterfeit-V3.0_fp16.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/cetusMix_Whalefall2.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/cuteyukimixAdorable_neochapter3.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/ekmix-pastel-fp16-no-ema.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/ex2K_sse2.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/kohakuV5_rev2.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/meinamix_meinaV11.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/oukaStar_10.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/pastelMixStylizedAnime_pastelMixPrunedFP16.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/rabbit_v6.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/sweetSugarSyndrome_rev15.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/AnythingV5Ink_ink.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/bartstyledbBlueArchiveArtStyleFineTunedModel_v10.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/meinapastel_v6Pastel.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/qteamixQ_omegaFp16.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/tmndMix_tmndMixSPRAINBOW.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/sd_xl_base_1.0_0.9vae.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animagine-xl-3.0.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/AnythingXL_xl.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/abyssorangeXLElse_v10.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animaPencilXL_v200.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animagine-xl-3.1.safetensors\", 1],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/heartOfAppleXL_v20.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/baxlBartstylexlBlueArchiveFlatCelluloid_xlv1.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohaku-xl-delta-rev1.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohakuXLEpsilon_rev1.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohaku-xl-epsilon-rev3.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohaku-xl-zeta.safetensors\", 1],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/nekorayxl_v06W3.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/CounterfeitXL-V1.0.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/ponyDiffusionV6XL_v6StartWithThisOne.safetensors\", 0]\n",
    "]\n",
    "# VAE models, in the same [download link, flag] format as sd_model.\n",
    "vae = [\n",
    "    [\"https://huggingface.co/licyk/sd-vae/resolve/main/sd_1.5/vae-ft-ema-560000-ema-pruned.safetensors\", 0],\n",
    "    [\"https://huggingface.co/licyk/sd-vae/resolve/main/sd_1.5/vae-ft-mse-840000-ema-pruned.safetensors\", 1],\n",
    "    [\"https://huggingface.co/licyk/sd-vae/resolve/main/sdxl_1.0/sdxl_fp16_fix_vae.safetensors\", 1]\n",
    "]\n",
    "\n",
    "manager = MANAGER(WORKSPACE, WORKFOLDER)\n",
    "\n",
    "# Input widgets; the install cell reads their `.value` attributes.\n",
    "torch_ver_state = widgets.Textarea(value=\"torch==2.3.0+cu121 torchvision==0.18.0+cu121 torchaudio==2.3.0+cu121\", placeholder=\"请填写 PyTorch 版本\", description=\"PyTorch 版本: \", disabled=False)\n",
    "xformers_ver_state = widgets.Textarea(value=\"xformers==0.0.26.post1\", placeholder=\"请填写 xFormers 版本\", description=\"xFormers 版本: \", disabled=False)\n",
    "use_ngrok_state = widgets.Checkbox(value=False, description=\"使用 Ngrok 内网穿透\", disabled=False)\n",
    "ngrok_token_state = widgets.Textarea(value=\"\", placeholder=\"请填写 Ngrok Token（可在 Ngrok 官网获取）\", description=\"Ngrok Token: \", disabled=False)\n",
    "use_cloudflare_state = widgets.Checkbox(value=False, description=\"使用 CloudFlare 内网穿透\", disabled=False)\n",
    "use_remote_moe_state = widgets.Checkbox(value=True, description=\"使用 remote.moe 内网穿透\", disabled=False)\n",
    "use_localhost_run_state = widgets.Checkbox(value=True, description=\"使用 localhost.run 内网穿透\", disabled=False)\n",
    "use_gradio_share_state = widgets.Checkbox(value=False, description=\"使用 Gradio 内网穿透\", disabled=False)\n",
    "# Custom model download: link, target file name and model kind\n",
    "model_url = widgets.Textarea(value=\"\", placeholder=\"请填写模型下载链接\", description=\"模型链接: \", disabled=False)  #@param {type:\"string\"}\n",
    "model_name = widgets.Textarea(value=\"\", placeholder=\"请填写模型名称，包括后缀名，例：kohaku-xl.safetensors\", description=\"模型名称: \", disabled=False)  #@param {type:\"string\"}\n",
    "model_type = widgets.Dropdown(options=[(\"Stable Diffusion 模型（大模型）\", \"sd\"), (\"VAE 模型\", \"vae\")], value=\"sd\", description='模型种类: ')\n",
    "\n",
    "display(torch_ver_state, xformers_ver_state, use_ngrok_state, ngrok_token_state, use_cloudflare_state, use_remote_moe_state, use_localhost_run_state, use_gradio_share_state, model_name, model_url, model_type)\n",
    "\n",
    "\n",
    "# Checkbox pickers; the returned lists are updated in place as boxes are ticked.\n",
    "sd_model_list = manager.select_list(sd_model,\"Stable Diffusion 模型\")\n",
    "vae_list = manager.select_list(vae, \"VAE 模型\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 安装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refuse to run unless the configuration cell has defined its sentinel.\n",
    "try:\n",
    "    INIT_CONFIG\n",
    "    INIT_CONFIG_1 = 1\n",
    "except NameError as err:\n",
    "    raise Exception(\"未运行\\\"参数配置\\\"单元\") from err\n",
    "\n",
    "# Read the choices made in the configuration widgets.\n",
    "TORCH_VER = torch_ver_state.value\n",
    "XFORMERS_VER = xformers_ver_state.value\n",
    "USE_NGROK = use_ngrok_state.value\n",
    "# Surrounding whitespace from the text box is not part of the token.\n",
    "NGROK_TOKEN = ngrok_token_state.value.strip()\n",
    "USE_CLOUDFLARE = use_cloudflare_state.value\n",
    "USE_REMOTE_MOE = use_remote_moe_state.value\n",
    "USE_LOCALHOST_RUN = use_localhost_run_state.value\n",
    "USE_GRADIO_SHARE = use_gradio_share_state.value\n",
    "USE_MIRROR = False\n",
    "\n",
    "sd_trainer = SD_TRAINER(WORKSPACE, WORKFOLDER)\n",
    "\n",
    "import os\n",
    "os.chdir(WORKSPACE)\n",
    "\n",
    "echo(\"开始安装 SD Trainer\")\n",
    "sd_trainer.install(TORCH_VER, XFORMERS_VER, sd_model_list, vae_list, USE_MIRROR)\n",
    "# Optional user-supplied model: both the link and the file name must be filled in.\n",
    "custom_model_url = model_url.value.strip()\n",
    "custom_model_name = model_name.value.strip()\n",
    "if custom_model_url != \"\" and custom_model_name != \"\":\n",
    "    if model_type.value == \"sd\":\n",
    "        sd_trainer.get_sd_model(custom_model_url, custom_model_name)\n",
    "    elif model_type.value == \"vae\":\n",
    "        sd_trainer.get_vae_model(custom_model_url, custom_model_name)\n",
    "cp_data()\n",
    "sd_trainer.clear_up(False)\n",
    "echo(\"SD Trainer 安装完成\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 启动"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refuse to run unless the install cell has defined its sentinel.\n",
    "try:\n",
    "    INIT_CONFIG_1\n",
    "except NameError as err:\n",
    "    raise Exception(\"未运行\\\"安装\\\"单元\") from err\n",
    "\n",
    "import os\n",
    "os.chdir(WORKSPACE + \"/\" + WORKFOLDER)\n",
    "echo(\"启动 SD Trainer 中\")\n",
    "sd_trainer.tun.start(ngrok=USE_NGROK, ngrok_token=NGROK_TOKEN, cloudflare=USE_CLOUDFLARE, remote_moe=USE_REMOTE_MOE, localhost_run=USE_LOCALHOST_RUN, gradio=USE_GRADIO_SHARE)\n",
    "# Launch the GUI from the configured folder instead of a hard-coded folder name.\n",
    "!python \"{WORKSPACE}/{WORKFOLDER}/gui.py\"\n",
    "sd_trainer.clear_up(False)\n",
    "echo(\"SD Trainer 已关闭\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
